import argparse

import torch

import model_gnn as model_management
import utils

# %%
def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is True for
    every non-empty string (including ``"False"``), so ``--save False`` would
    still save. This accepts the usual true/false spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line configuration for GNN self-supervised pre-training.
parser = argparse.ArgumentParser()
parser.add_argument('--datasets', type=str, default='citation',
                    help='specify multiple datasets in DOUBLE QUOTES separated by commas, e.g., "cora, dblp"')
parser.add_argument('--epochs', '-e', type=int, default=2500,
                    help='number of epochs for pre-training')
# Used both as the stem output width and as the width of every backbone layer.
parser.add_argument('--encoder_dim', type=int, default=256,
                    help='hidden dimension of the stem and GNN backbone layers')
parser.add_argument('--model_type', type=str, default='gcn',
                    help='type of GNN architecture to use')
parser.add_argument('--save', type=_str2bool, default=True,
                    help='whether to save model after pre-training (true/false)')
args = parser.parse_args()

# %%
# Dataset names arrive as one comma-separated string. Strip surrounding
# whitespace so the documented form "cora, dblp" yields 'dblp' rather than
# ' dblp', and ignore empty entries left by stray commas.
dataset_names = [name.strip() for name in args.datasets.split(',') if name.strip()]
if not dataset_names:
    parser.error('--datasets must name at least one dataset')

data = utils.create_data_structure(
    datasets=dataset_names,
    ssl=True,
    model_type='gnn',
)
num_datasets = len(data.datasets)

# %%
def _make_stem(num_node_features, hidden_dim):
    """Return the stem config for one dataset.

    The stem maps that dataset's node features to ``hidden_dim`` using layer
    type 'lin' with ReLU activation. It declares no edge features.
    """
    return dict(
        num_node_f=num_node_features,
        num_edge_f=[],
        num_layer_features=[hidden_dim],
        layer_type='lin',
        act='relu',
    )


# One stem per dataset, since each dataset can have a different node-feature size.
stems = [_make_stem(dataset.num_features, args.encoder_dim) for dataset in data.datasets]

_NUM_BACKBONE_LAYERS = 3
# Shared backbone: its input width is the (common) output width of the stems.
backbone = dict(
    num_in_features=stems[0]['num_layer_features'][-1],
    num_layer_features=[args.encoder_dim] * _NUM_BACKBONE_LAYERS,
    layer_type=args.model_type,
    act='relu',
    # NOTE(review): presumably only consulted by attention-based layer types;
    # confirm in model_gnn.
    num_heads=[8] * _NUM_BACKBONE_LAYERS,
)

ssl_tasks = ['pairsim']

# %%
# Prefer the GPU, but fall back to CPU instead of failing on machines without
# CUDA. Behavior on CUDA machines is unchanged.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = model_management.GNNEncoder(stems=stems, backbone=backbone, device=device)
ssl_model = model_management.SSLGNN(encoder=encoder, data=data, ssl_tasks=ssl_tasks)
# The encoder is handed to SSLGNN (which presumably keeps its own reference),
# so drop the module-level name.
del encoder

# pretrain() receives a checkpoint path only when saving was requested, else None.
save_path = None
if args.save:
    model_name = f'ssl_gnn{num_datasets}_datasets_gnn.pt'
    save_path = utils.create_save_path(model_name)
ssl_model.pretrain(num_epochs=args.epochs, save_path=save_path)
